from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
from linearmodels.iv import IV2SLS


# Sentinel values that clean_numeric() turns into NaN.
# NOTE(review): the same set is applied to every variable passed through
# clean_numeric(), so a legitimate observation that happens to equal one of
# these numbers (e.g. a 7 on a 1-10 scale, or 77 in a years/age column) is
# masked as well -- confirm per variable before relying on it.
MISSING_CODES = {7, 8, 9, 77, 88, 99, 777, 888, 999, 7777, 8888, 9999}


def clean_numeric(series: pd.Series, missing_codes: set[float] | None = None) -> pd.Series:
    """Coerce *series* to numeric and replace missing-value sentinels with NaN.

    Args:
        series: Raw column; anything that cannot be parsed as a number
            becomes NaN.
        missing_codes: Values to treat as missing. Any iterable of numbers is
            accepted. When omitted, the module-wide ``MISSING_CODES`` set is
            used, which preserves the previous behaviour.

    Returns:
        A numeric Series with the same index as *series*.

    Note:
        The default set contains small numbers (7, 8, 9, 77, ...) that are
        valid observations for some variables. Pass an explicit
        ``missing_codes`` for such columns so real values are not discarded.
    """
    codes = MISSING_CODES if missing_codes is None else missing_codes
    values = pd.to_numeric(series, errors="coerce")
    # list(...) lets callers pass any iterable (set, tuple, generator, ...).
    return values.mask(values.isin(list(codes)))


def recode_yes_no(series: pd.Series) -> pd.Series:
    """Turn a 1/2 coded item into a 1.0/0.0 indicator.

    The column is first passed through ``clean_numeric``; a value of 1 maps
    to 1.0, a value of 2 maps to 0.0, and every other value (including NaN)
    becomes NaN.
    """
    indicator_by_code = {1: 1.0, 2: 0.0}
    cleaned = clean_numeric(series)
    return cleaned.map(indicator_by_code)


def zscore(series: pd.Series) -> pd.Series:
    """Standardise *series* to zero mean and unit (population) variance.

    Non-numeric entries are coerced to NaN and ignored when computing the
    mean and the ``ddof=0`` standard deviation. If the standard deviation is
    undefined or not positive, the coerced series multiplied by zero is
    returned (zeros where a value exists, NaN elsewhere).
    """
    values = pd.to_numeric(series, errors="coerce")
    centre = values.mean()
    spread = values.std(ddof=0)
    if pd.notna(spread) and spread > 0:
        return (values - centre) / spread
    # Degenerate spread: no meaningful scaling is possible.
    return values * 0


@dataclass(frozen=True)
class Spec:
    """A named set of control variables for one regression specification."""

    # Label written to the "spec" column of the results.
    name: str
    # Column names entered into the exogenous matrix as they are.
    controls_cont: list[str]
    # Column names expanded into dummy variables by build_exog().
    controls_cat: list[str]


# Non-response codes per variable. A single global code list is not safe here:
# 7/8/9 are valid income deciles, years of schooling and news minutes, and
# 77/88/99 are valid ages, so masking them would silently drop real
# respondents from every regression.
# NOTE(review): codes follow the ESS codebook convention (refusal / don't know /
# no answer scaled to the width of the item) -- confirm against the variable
# list of the extract actually used.
_CODES_ONE_DIGIT = (7, 8, 9)
_CODES_TWO_DIGIT = (77, 88, 99)
_CODES_FOUR_DIGIT = (7777, 8888, 9999)
_ESS_MISSING_CODES = {
    "netusoft": _CODES_ONE_DIGIT,
    "gndr": _CODES_ONE_DIGIT,
    "uempla": _CODES_ONE_DIGIT,
    "domicil": _CODES_ONE_DIGIT,
    "eduyrs": _CODES_TWO_DIGIT,
    "hinctnta": _CODES_TWO_DIGIT,
    "agea": (999,),
    "nwspol": _CODES_FOUR_DIGIT,
}


def _clean_ess_item(df: pd.DataFrame, column: str) -> pd.Series:
    """Coerce one survey column to numeric and mask only that item's own missing codes."""
    values = pd.to_numeric(df[column], errors="coerce")
    return values.mask(values.isin(_ESS_MISSING_CODES[column]))


def build_dataset(root: Path) -> pd.DataFrame:
    """Assemble the respondent-level analysis frame.

    Reads the survey extract (``outputs/ess_r8_r11_min.parquet``) and the
    region-year broadband table (``data_external/broadband_region_year_
    eurostat_isoc_r_broad_h.csv``) below *root*, left-joins the regional
    rates onto respondents by region and survey year, and derives the
    endogenous regressor, instruments, outcomes, placebos, controls, the
    ``edu_high`` indicator and the country-by-year fixed-effect key ``ct``.

    Args:
        root: Project root containing ``outputs/`` and ``data_external/``.

    Returns:
        One row per respondent that has at least one non-missing instrument
        and a defined ``edu_high``.
    """
    ess = pd.read_parquet(root / "outputs" / "ess_r8_r11_min.parquet")
    reg = pd.read_csv(root / "data_external" / "broadband_region_year_eurostat_isoc_r_broad_h.csv")

    # Keep NUTS1/2 where regional matching is meaningful
    ess = ess[ess["regunit"].isin([1, 2])].copy()
    ess["region"] = ess["region"].astype(str).str.upper().str.strip()
    ess = ess[ess["region"] != "99999"].copy()

    reg["region"] = reg["region"].astype(str).str.upper().str.strip()
    reg["year"] = pd.to_numeric(reg["year"], errors="coerce").astype("Int64")
    reg["broadband_pc_hh"] = pd.to_numeric(reg.get("broadband_pc_hh"), errors="coerce")
    reg["internet_access_pc_hh"] = pd.to_numeric(reg.get("internet_access_pc_hh"), errors="coerce")

    # NOTE(review): a repeated (region, year) pair in the regional table would
    # duplicate respondents here; the merge does not check for that.
    df = ess.merge(reg, left_on=["region", "survey_year"], right_on=["region", "year"], how="left")

    # Endogenous internet use
    df["netusoft"] = _clean_ess_item(df, "netusoft")

    # Instruments (region-year), standardised over the merged respondent rows
    df["z_broadband_pc_hh"] = zscore(df["broadband_pc_hh"])
    df["z_internet_access_pc_hh"] = zscore(df["internet_access_pc_hh"])

    # Outcomes: main. Yes/no items only keep codes 1 and 2, so the generic
    # cleaning inside recode_yes_no cannot discard valid answers.
    df["y_vote"] = recode_yes_no(df["vote"])
    df["y_sgnptit"] = recode_yes_no(df["sgnptit"])
    df["y_pbldmn"] = recode_yes_no(df["pbldmn"])
    df["y_bctprd"] = recode_yes_no(df["bctprd"])
    df["y_contplt"] = recode_yes_no(df["contplt"])
    df["y_badge"] = recode_yes_no(df["badge"])
    df["y_pstplonl"] = recode_yes_no(df["pstplonl"]) if "pstplonl" in df.columns else np.nan
    part_cols = ["y_sgnptit", "y_pbldmn", "y_bctprd", "y_contplt", "y_badge", "y_pstplonl"]
    # Row mean over the answered items; a row with no answered item is NaN.
    df["y_part_index"] = df[part_cols].mean(axis=1, skipna=True)

    # Mechanism candidate
    df["nwspol"] = _clean_ess_item(df, "nwspol")

    # Placebo outcomes
    df["gndr_num"] = _clean_ess_item(df, "gndr")
    df["brncntr_bin"] = recode_yes_no(df["brncntr"])

    # Controls
    df["agea"] = _clean_ess_item(df, "agea")
    df["gndr"] = df["gndr_num"]
    df["eduyrs"] = _clean_ess_item(df, "eduyrs")
    df["hinctnta"] = _clean_ess_item(df, "hinctnta")
    df["uempla"] = _clean_ess_item(df, "uempla")
    df["domicil"] = _clean_ess_item(df, "domicil")

    # Education indicator: 1.0 above 12 years, 0.0 otherwise, NaN when unknown
    eduyrs = df["eduyrs"]
    df["edu_high"] = (eduyrs > 12).astype(float).where(eduyrs.notna())

    # FE key (country×survey_year)
    df["ct"] = df["cntry"].astype(str) + "_" + df["survey_year"].astype("Int64").astype(str)

    # Keep rows where at least one instrument exists and edu_high exists
    has_instrument = df[["z_broadband_pc_hh", "z_internet_access_pc_hh"]].notna().any(axis=1)
    df = df[has_instrument].copy()
    df = df[df["edu_high"].notna()].copy()

    return df


def build_exog(d: pd.DataFrame, controls_cont: list[str], controls_cat: list[str]) -> pd.DataFrame:
    """Build the exogenous design matrix for one estimation sample.

    Columns, in order: an intercept ``const``, the continuous controls as
    they are, dummies for each categorical control (first level dropped),
    the ``edu_high`` main effect, and dummies for the ``ct`` fixed-effect key
    (first level dropped).

    Args:
        d: Estimation sample; must contain ``edu_high``, ``ct`` and every
            listed control.
        controls_cont: Columns entered without transformation.
        controls_cat: Columns cast to integer labels and dummy-encoded.

    Returns:
        A frame indexed like *d*. Duplicate column names are collapsed to
        their first occurrence and columns without variation (other than
        ``const``) are removed.
    """
    parts = [pd.DataFrame({"const": 1.0}, index=d.index)]

    if controls_cont:
        parts.append(d[controls_cont])

    if controls_cat:
        labels = d[controls_cat].copy()
        for name in controls_cat:
            labels[name] = labels[name].astype("Int64").astype(str)
        # dtype=float: recent pandas returns bool dummies by default, which
        # would make the matrix mixed-dtype and force an object-array detour
        # when it is converted to a float array.
        parts.append(pd.get_dummies(labels, prefix=controls_cat, drop_first=True, dtype=float))

    # Always include edu_high main effect
    parts.append(d[["edu_high"]])

    # ctFE
    parts.append(pd.get_dummies(d["ct"].astype(str), prefix="ct", drop_first=True, dtype=float))

    exog = pd.concat(parts, axis=1)
    # A control may repeat a column added above (e.g. edu_high); keep the first.
    exog = exog.loc[:, ~exog.columns.duplicated()]
    # Drop constant columns beyond intercept (can happen in filtered subsamples)
    varies = (exog.nunique(dropna=False) > 1).to_numpy()
    keep = varies | (exog.columns == "const")
    return exog.loc[:, keep]


def fit_iv_interaction(df: pd.DataFrame, y: str, spec: Spec):
    """Estimate the education-interacted 2SLS model for outcome *y*.

    Endogenous regressors are ``netusoft`` and ``netusoft * edu_high``;
    excluded instruments are the two standardised regional rates and their
    products with ``edu_high``. Exogenous regressors come from
    ``build_exog``. Standard errors are clustered by ``region``.

    Args:
        df: Analysis frame holding the outcome, the core variables and the
            controls named in *spec*.
        y: Outcome column name.
        spec: Control-variable specification.

    Returns:
        The fitted result object returned by ``IV2SLS(...).fit``.

    Raises:
        ValueError: If no row has complete data for the required columns.
    """
    required = [
        y,
        "netusoft",
        "edu_high",
        "z_broadband_pc_hh",
        "z_internet_access_pc_hh",
        "ct",
        "region",
        *spec.controls_cont,
        *spec.controls_cat,
    ]
    # A control may repeat a core variable (e.g. edu_high). Selecting the same
    # label twice would make d[name] a two-column frame and break the
    # interaction products below, so keep each name once, in order.
    cols = list(dict.fromkeys(required))
    d = df[cols].dropna().copy()
    if d.empty:
        raise ValueError(f"no complete cases for outcome {y!r} under spec {spec.name!r}")

    d["netusoft_x_edu"] = d["netusoft"] * d["edu_high"]
    d["z_broadband_x_edu"] = d["z_broadband_pc_hh"] * d["edu_high"]
    d["z_inetacc_x_edu"] = d["z_internet_access_pc_hh"] * d["edu_high"]

    exog = build_exog(d, spec.controls_cont, spec.controls_cat)
    endog = d[["netusoft", "netusoft_x_edu"]]
    instr = d[["z_broadband_pc_hh", "z_internet_access_pc_hh", "z_broadband_x_edu", "z_inetacc_x_edu"]]

    model = IV2SLS(d[y], exog=exog, endog=endog, instruments=instr)
    return model.fit(cov_type="clustered", clusters=d["region"])


def extract(res, outcome: str, label: str, spec_name: str) -> dict:
    """Flatten a fitted interaction-IV result into one results-table row.

    Reads the ``netusoft`` and ``netusoft_x_edu`` entries of ``res.params``
    and ``res.std_errors`` (NaN when absent), ``res.nobs``, and the
    ``f.stat`` entry of ``res.first_stage.diagnostics`` for each endogenous
    regressor. When a regressor is not in the diagnostics index, the first
    row is used for ``netusoft`` and the last row for ``netusoft_x_edu``.

    Args:
        res: Fitted result object.
        outcome: Outcome column name, copied into the row.
        label: Human-readable outcome label, copied into the row.
        spec_name: Specification name, copied into the row.

    Returns:
        A dict with the coefficient for the low-education group, the implied
        high-education coefficient (sum of the two estimates), the
        interaction estimate, standard errors and first-stage F statistics.
    """
    base, interaction = "netusoft", "netusoft_x_edu"

    coef_base = float(res.params.get(base, np.nan))
    coef_interaction = float(res.params.get(interaction, np.nan))
    se_base = float(res.std_errors.get(base, np.nan))
    se_interaction = float(res.std_errors.get(interaction, np.nan))

    first_stage = res.first_stage.diagnostics
    if base in first_stage.index:
        row_base = first_stage.loc[base]
    else:
        row_base = first_stage.iloc[0]
    if interaction in first_stage.index:
        row_interaction = first_stage.loc[interaction]
    else:
        row_interaction = first_stage.iloc[-1]

    row = {"spec": spec_name, "outcome": outcome, "label": label}
    row["nobs"] = int(res.nobs)
    row["coef_lowEdu"] = coef_base
    row["se_lowEdu"] = se_base
    row["coef_highEdu"] = coef_base + coef_interaction
    row["coef_deltaHighMinusLow"] = coef_interaction
    row["se_deltaHighMinusLow"] = se_interaction
    row["fsF_netusoft"] = float(row_base.get("f.stat", np.nan))
    row["fsF_netusoft_x_edu"] = float(row_interaction.get("f.stat", np.nan))
    return row


def _plain_pipe_table(frame: pd.DataFrame) -> str:
    """Render *frame* as a basic pipe table without third-party helpers."""
    header = [str(col) for col in frame.columns]
    lines = [
        "| " + " | ".join(header) + " |",
        "|" + "|".join("---" for _ in header) + "|",
    ]
    for record in frame.itertuples(index=False, name=None):
        cells = ("" if pd.isna(value) else str(value) for value in record)
        lines.append("| " + " | ".join(cells) + " |")
    return "\n".join(lines)


def to_md(df: pd.DataFrame) -> str:
    """Render a results frame as a Markdown table with fixed decimals.

    Coefficient and standard-error columns are shown with four decimals,
    first-stage F statistics with two; missing values become empty cells.
    Columns that are absent from *df* are skipped. The input frame is not
    modified.

    Args:
        df: Results table, typically built from ``extract`` rows.

    Returns:
        The table as Markdown text.
    """
    decimals_by_column = {
        "coef_lowEdu": 4,
        "se_lowEdu": 4,
        "coef_highEdu": 4,
        "coef_deltaHighMinusLow": 4,
        "se_deltaHighMinusLow": 4,
        "fsF_netusoft": 2,
        "fsF_netusoft_x_edu": 2,
    }
    d = df.copy()
    for column, decimals in decimals_by_column.items():
        if column not in d.columns:
            continue
        # Bind the precision as a default so each column keeps its own format.
        d[column] = d[column].map(lambda x, nd=decimals: f"{x:.{nd}f}" if pd.notna(x) else "")
    try:
        # disable_numparse: the table renderer otherwise re-parses
        # numeric-looking strings and drops the trailing zeros added above.
        return d.to_markdown(index=False, disable_numparse=True)
    except ImportError:
        # The renderer behind to_markdown is an optional dependency.
        return _plain_pipe_table(d)


def main() -> None:
    """Run every specification/outcome pair and write the result files.

    For each spec and each main or placebo outcome, the interaction-IV model
    is fitted; a failed fit is recorded as an error line instead of stopping
    the run. Writes a CSV with all rows, two Markdown tables (main outcomes
    and placebos) and a notes file listing the specs, any errors and the
    output paths, all under ``<root>/outputs`` where ``<root>`` is the parent
    of this file's directory.
    """
    root = Path(__file__).resolve().parents[1]
    out_dir = root / "outputs"
    out_dir.mkdir(parents=True, exist_ok=True)

    df = build_dataset(root)

    specs = [
        Spec(name="minimal", controls_cont=["agea", "gndr", "eduyrs", "hinctnta"], controls_cat=[]),
        Spec(
            name="expanded",
            controls_cont=["agea", "gndr", "eduyrs", "hinctnta", "uempla"],
            controls_cat=["domicil"],
        ),
        # Expanded+born-in-country as continuous control (binary)
        Spec(
            name="expanded_plus_origin",
            controls_cont=["agea", "gndr", "eduyrs", "hinctnta", "uempla", "brncntr_bin"],
            controls_cat=["domicil"],
        ),
    ]

    outcomes = [
        ("y_part_index", "Participation index"),
        ("y_vote", "Voted"),
        ("y_sgnptit", "Signed petition"),
        ("y_contplt", "Contacted politician"),
        ("y_bctprd", "Boycott"),
        ("y_pstplonl", "Posted politics online"),
    ]

    placebos = [
        ("gndr", "Placebo: gender code"),
        ("brncntr_bin", "Placebo: born in country"),
    ]

    rows = []
    errors = []

    for spec in specs:
        for y, label in outcomes + placebos:
            if y not in df.columns:
                continue
            # Avoid controlling for the same variable as outcome
            spec_use = spec
            if y in spec.controls_cont or y in spec.controls_cat:
                cc = [c for c in spec.controls_cont if c != y]
                cat = [c for c in spec.controls_cat if c != y]
                spec_use = Spec(name=spec.name, controls_cont=cc, controls_cat=cat)
            # Deliberately broad: one failing model must not abort the whole
            # robustness run; the failure is written to the notes file.
            try:
                res = fit_iv_interaction(df, y, spec_use)
                rows.append(extract(res, y, label, spec.name))
            except Exception as e:
                errors.append(f"[{spec.name}] {y}: {type(e).__name__}: {e}")

    out_csv = out_dir / "iv_interaction_controls_robustness.csv"
    out_md_main = out_dir / "iv_interaction_controls_robustness_main.md"
    out_md_placebo = out_dir / "iv_interaction_controls_robustness_placebo.md"
    out_txt = out_dir / "iv_interaction_controls_robustness_notes.txt"

    if rows:
        res_df = pd.DataFrame(rows).sort_values(["label", "spec"]).reset_index(drop=True)
        is_placebo = res_df["label"].str.startswith("Placebo:", na=False)
        main_md = to_md(res_df[~is_placebo].copy())
        placebo_md = to_md(res_df[is_placebo].copy())
    else:
        # Every fit failed: there are no columns to sort or split on. Still
        # produce all output files so the recorded errors are not lost.
        res_df = pd.DataFrame()
        main_md = placebo_md = "No models were estimated successfully; see the notes file.\n"

    res_df.to_csv(out_csv, index=False, encoding="utf-8")
    out_md_main.write_text(main_md, encoding="utf-8")
    out_md_placebo.write_text(placebo_md, encoding="utf-8")

    notes = []
    notes.append("Interaction-IV robustness: minimal vs expanded controls (ctFE; clustered by region)")
    notes.append("Specs:")
    notes.append("- minimal: age, gender, education years, income")
    notes.append("- expanded: + unemployment (uempla) + domicil dummies")
    notes.append("- expanded_plus_origin: + born-in-country binary (brncntr_bin)")
    notes.append("")
    if errors:
        notes.append("Errors:")
        notes.extend(errors)
        notes.append("")
    notes.append(f"Wrote: {out_csv}")
    notes.append(f"Wrote: {out_md_main}")
    notes.append(f"Wrote: {out_md_placebo}")
    out_txt.write_text("\n".join(notes) + "\n", encoding="utf-8")

    print(f"Wrote: {out_csv}")
    print(f"Wrote: {out_md_main}")
    print(f"Wrote: {out_md_placebo}")
    print(f"Wrote: {out_txt}")
    if errors:
        # Make failures visible on the console, not only in the notes file.
        print(f"{len(errors)} model(s) failed; details in {out_txt}")


# Run the full robustness pipeline only when executed as a script.
if __name__ == "__main__":
    main()
